import matplotlib.pyplot as plt
from matplotlib import rcParams
import numpy as np
import scipy.optimize as sco
import matplotlib.gridspec as gridspec
import pymatgen.io.vasp.outputs as vo
import pandas
import glob

def points_per_band(data):
    """Return the number of points in each band of a stacked band array.

    ``data`` holds all bands one after another.  Column 0 is the path
    coordinate, which restarts (decreases) where a new band begins.  The
    length of a band is therefore the index of the first decrease in
    column 0, plus one.

    If column 0 never decreases, the whole array is a single band and its
    full length is returned.
    """
    data = np.asarray(data)
    if len(data) < 2:
        return len(data)
    # Indices i where data[i + 1, 0] < data[i, 0], i.e. the last point of a band.
    drops = np.flatnonzero(np.diff(data[:, 0]) < 0)
    if drops.size == 0:
        # No restart: every point belongs to one band.
        return len(data)
    return int(drops[0]) + 1

def plot_bands(ax, fname, color='c', bandpathfile = "bandstructure_kpoints_hex_dense", ymin = -14, ymax = 8):
    """Plot the band structure stored in directory ``fname`` onto ``ax``.

    Files read from ``fname``:
      * ``QUASI.bandstr``   -- all bands stacked one after another; column 0
        is the path coordinate, column 1 the band energy.
      * ``BPOINTS.bandstr`` -- column 1 holds the path coordinates of the
        high-symmetry points, used as x ticks and vertical guide lines.
      * ``bandpathfile``    -- whitespace-separated k-path file; after the
        first 3 lines are skipped, column 4 provides the tick labels.
      * ``vasprun.xml``     -- only the ``NELECT`` parameter is used.

    Bands lying entirely outside ``[ymin, ymax]`` are not drawn, and the y
    range of ``ax`` is fixed to that window.  ``color`` is the line colour
    of every band.
    """
    data = np.loadtxt(fname + "/QUASI.bandstr")
    # atleast_2d keeps column indexing valid if the file has a single row.
    bpoints = np.atleast_2d(np.loadtxt(fname + "/BPOINTS.bandstr"))
    vrun = vo.Vasprun(fname + "/vasprun.xml", parse_potcar_file=False)
    nelect = int(vrun.parameters['NELECT'])

    rlen = points_per_band(data)
    nbands = len(data) // rlen

    # Path coordinates of the high-symmetry points and their labels.
    point = bpoints[:, 1]
    pointlabel = pandas.read_csv(fname + "/" + bandpathfile, skiprows=3,
                                 sep=r'\s+', header=None).values[:, 4]
    ax.set_xticks(point)
    ax.set_xticklabels(pointlabel)
    ax.set_xlim([np.min(data[:, 0]), np.max(data[:, 0])])
    ax.axhline(0.0, color='k', linewidth=0.2)
    # Dotted line at the minimum of band number `nelect` (0-based), i.e. the
    # first band after the lowest `nelect` bands.
    # NOTE(review): that is the lowest empty band only if each band holds one
    # electron (e.g. spin-orbit / noncollinear runs) -- confirm for these data.
    ax.axhline(np.min(data[rlen * nelect:rlen * (nelect + 1), 1]), ls=':', lw=.5, c='k')
    for x_sym in point:
        ax.axvline(x_sym, color='k', linewidth=0.2)

    for start in range(0, nbands * rlen, rlen):
        band = data[start:start + rlen]
        # Skip bands that never enter the visible energy window.
        if np.min(band[:, 1]) > ymax or np.max(band[:, 1]) < ymin:
            continue
        ax.plot(band[:, 0], band[:, 1], linewidth=2, color=color)
    ax.set_ylim([ymin, ymax])


def _composition_averages(band_data, stat, n_steps=8):
    """Return configuration-weighted band energies for each alloy composition.

    Rows of ``band_data`` and ``stat`` are grouped by their column 0, the
    composition index ``i`` in ``0..n_steps``.  Within one composition the
    two arrays are assumed to list the same configurations in the same order.
    NOTE(review): nothing here checks that, so confirm the files agree.
    ``stat[:, 2]`` is the per-configuration weight, and ``stat[0, 3]`` of
    each group normalises the spread.

    For every energy column ``j >= 2`` of ``band_data`` this computes:
      * the weighted mean over configurations, ignoring NaN entries;
      * the weighted spread ``sqrt(sum(w * (mean - E)**2) / stat[0, 3])``.

    Returns ``(x, y, yerr)``.  ``x[i] = i / n_steps``.  ``y`` and ``yerr``
    have one row per composition and one column per energy column.
    """
    x, y, yerr = [], [], []
    for i in range(n_steps + 1):
        st = stat[stat[:, 0] == i]
        da = band_data[band_data[:, 0] == i]
        weights = st[:, 2]
        means, spreads = [], []
        for j in range(2, da.shape[1]):
            energies = da[:, j]
            # NaN marks a missing value: nansum drops it, and its weight is
            # left out of the denominator.
            mean = np.nansum(energies * weights) / np.sum(weights[np.isfinite(energies)])
            spread = np.sqrt(np.nansum(weights * (mean - energies) ** 2 / st[0, 3]))
            means.append(mean)
            spreads.append(spread)
        x.append(i / n_steps)
        y.append(means)
        yerr.append(spreads)
    return np.array(x), np.array(y), np.array(yerr)


def _plot_bandstructures(fig, gs):
    """Draw the band-structure panels: Si in row 0 and Ge in row 1 of ``gs``.

    Cubic structures occupy columns 0-1 and hexagonal (2H) ones columns 2-5.
    All panels share the -1..3 eV window and tick positions.  The hexagonal
    panels hide their y tick labels.
    """
    yticks = [-1, 0, 1, 2, 3]

    ax_gec = fig.add_subplot(gs[1, :2])
    plot_bands(ax_gec, "ge-c-bands", color='#5aa800', ymin=-1, ymax=3)
    ax_gec.set_yticks(yticks)
    ax_ge2h = fig.add_subplot(gs[1, 2:6])
    plot_bands(ax_ge2h, "ge-2h-bands", color='#5aa800', ymin=-1, ymax=3)
    ax_ge2h.set_yticks(yticks)
    ax_ge2h.set_yticklabels([])

    ax_sic = fig.add_subplot(gs[0, :2])
    plot_bands(ax_sic, "si-c-bands", color='#0098e9', ymin=-1, ymax=3)
    # Pin the ticks like the other three panels so the rows line up.
    ax_sic.set_yticks(yticks)
    ax_si2h = fig.add_subplot(gs[0, 2:6])
    plot_bands(ax_si2h, "si-2h-bands", color='#0098e9', ymin=-1, ymax=3)
    ax_si2h.set_yticks(yticks)
    ax_si2h.set_yticklabels([])

    ax_sic.set_ylabel("Energy (eV)")
    ax_gec.set_ylabel("Energy (eV)")

    ax_gec.text(0.2, 2.3, "cub-Ge")
    ax_sic.text(0.2, 1.9, "cub-Si")
    ax_ge2h.text(0.175, 2.3, "hex-Ge")
    ax_si2h.text(0.175, 1.9, "hex-Si")
    # Arrow from (0.077, 0.0) to (0.232, 1.1) in data coordinates.  The (empty)
    # text is passed positionally: the `s=` keyword was removed in Matplotlib 3.5.
    ax_si2h.annotate('', xy=(0.155 + 0.077, 1.1), xytext=(0.077, 0.0),
                     arrowprops=dict(arrowstyle='->'))


def _plot_band_energies(ax, band_data, stat):
    """Plot selected averaged band energies against Ge content, with quadratic fits."""
    x, y, _ = _composition_averages(band_data, stat)
    # One label/colour per energy column of "Band_energies" (after the two
    # index columns); only columns 5-8 are plotted.
    labels = [r"$\Gamma_{7v-}^+$", r"$\Gamma_{7v+}^+$", r"$\Gamma_{9v}^+$",
              r"$\Gamma_{8c}^-$", r"$\Gamma_{7c}^-$", "M", "L", "U", r"$\Gamma$"]
    colors = ["", "", "", "", "", "darkgreen", "blue", "red", "orange", "cyan"]

    # The x axis is labelled "Ge content" and shows 1 - x.
    ge_content = 1 - x
    for i in (5, 6, 7, 8):
        fit = np.poly1d(np.polyfit(ge_content, y[:, i], 2))
        ax.plot(ge_content, y[:, i], marker="o", linestyle='', color=colors[i])
        ax.plot(ge_content, fit(ge_content), marker="", color=colors[i], label=labels[i])

    ax.legend()
    ax.set_xlabel("Ge content")
    ax.set_ylabel("Energy (eV)")


def _load_lifetimes(base="Optics", max_si=3):
    """Load radiative-lifetime tables from ``<base>/POSCAR_<n>_<config>`` folders.

    Returns a list indexed by ``n`` (0..max_si).  Each entry holds the
    ``radiative_lifetime_QS.dat`` arrays of that ``n``, in sorted folder
    order.  Folders with ``n > max_si`` are ignored.
    """
    lifetimes = [[] for _ in range(max_si + 1)]
    for folder in sorted(glob.glob(base + "/POSCAR_*")):
        n_si = int(folder.split("_")[-2])
        if n_si > max_si:
            continue
        lifetimes[n_si].append(np.loadtxt(folder + "/radiative_lifetime_QS.dat"))
    return lifetimes


def _plot_lifetimes(ax, lifetimes, best=(0, 0, 2, 5)):
    """Plot the lifetime panel: one chosen configuration per composition, plus GaAs.

    ``best[n]`` is the index (in sorted folder order) of the configuration
    whose lifetime curve is drawn for ``lifetimes[n]``.  Each composition
    therefore needs at least ``best[n] + 1`` loaded tables.
    """
    colors = ["red", "green", "blue", "orange", '#0098e9']
    for n_si, (runs, pick) in enumerate(zip(lifetimes, best)):
        # The temperature axis is taken from the first configuration.
        ax.semilogy(runs[0][:, 0], runs[pick][:, 1],
                    label="Si$_{}$Ge$_{}$".format(n_si, 8 - n_si), c=colors[n_si], ls="-")

    data_gaas = np.loadtxt("radiative_lifetime_QS_GaAs.dat")
    ax.semilogy(data_gaas[:, 0], data_gaas[:, 1], label="GaAs", c=colors[4], ls="-")

    ax.set_ylim([1e-9, 1e-4])
    ax.legend(loc='upper right')
    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel(r"Radiative lifetime $\tau$ (s)")


def main():
    """Build the four-panel Si/Ge figure and save it as a PDF.

    Inputs are read from the working directory: the four ``*-bands``
    directories, "Band_energies", "alloy_stat", the "Optics/POSCAR_*"
    folders and "radiative_lifetime_QS_GaAs.dat".
    """
    rcParams['font.size'] = 15
    rcParams['font.family'] = 'DejaVu Sans'
    rcParams['mathtext.fontset'] = 'dejavusans'

    fig = plt.figure(figsize=(15, 10))
    gs = gridspec.GridSpec(2, 10)
    sp = fig.add_subplot(gs[0, 7:])
    ac = fig.add_subplot(gs[1, 7:])

    _plot_bandstructures(fig, gs)
    _plot_band_energies(sp, np.loadtxt("Band_energies"), np.loadtxt("alloy_stat"))
    _plot_lifetimes(ac, _load_lifetimes("Optics"))

    fig.text(.0, .95, "a)")
    fig.text(.0, .45, "b)")
    fig.text(.635, .95, "c)")
    fig.text(.635, .45, "d)")

    gs.update(left=0.05, right=0.97, top=0.97, bottom=0.06, wspace=.2, hspace=.2)
    fig.savefig("si-ge-bandstructure-gap-lifetime.pdf")


if __name__ == "__main__":
    main()
